{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------------------------------------------------\n",
    "# Numenta Platform for Intelligent Computing (NuPIC)\n",
    "# Copyright (C) 2019, Numenta, Inc.  Unless you have an agreement\n",
    "# with Numenta, Inc., for a separate license for this software code, the\n",
    "# following terms and conditions apply:\n",
    "#\n",
    "# This program is free software: you can redistribute it and/or modify\n",
    "# it under the terms of the GNU Affero Public License version 3 as\n",
    "# published by the Free Software Foundation.\n",
    "#\n",
    "# This program is distributed in the hope that it will be useful,\n",
    "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n",
    "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
    "# See the GNU Affero Public License for more details.\n",
    "#\n",
    "# You should have received a copy of the GNU Affero Public License\n",
    "# along with this program.  If not, see http://www.gnu.org/licenses.\n",
    "#\n",
    "# http://numenta.org/licenses/\n",
    "# ----------------------------------------------------------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch\n",
    "!pip install torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "SEED = 18\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Use GPU if available\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, loader, optimizer, criterion, post_batch_callback=None):\n",
    "    \"\"\"\n",
    "    Train the model using given dataset loader. \n",
    "    Called on every epoch.\n",
    "    :param model: pytorch model to be trained\n",
    "    :type model: torch.nn.Module\n",
    "    :param loader: dataloader configured for the epoch.\n",
    "    :type loader: :class:`torch.utils.data.DataLoader`\n",
    "    :param optimizer: Optimizer object used to train the model.\n",
    "    :type optimizer: :class:`torch.optim.Optimizer`\n",
    "    :param criterion: loss function to use\n",
    "    :type criterion: function\n",
    "    :param post_batch_callback: function(model) to call after every batch\n",
    "    :type post_batch_callback: function\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(tqdm(loader, leave=False)):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if post_batch_callback is not None:\n",
    "            post_batch_callback(model)\n",
    "\n",
    "\n",
    "def test(model, loader, criterion):\n",
    "    \"\"\"\n",
    "    Evaluate pre-trained model using given dataset loader.\n",
    "    Called on every epoch.\n",
    "    :param model: Pretrained pytorch model\n",
    "    :type model: torch.nn.Module\n",
    "    :param loader: dataloader configured for the epoch.\n",
    "    :type loader: :class:`torch.utils.data.DataLoader`\n",
    "    :param criterion: loss function to use\n",
    "    :type criterion: function\n",
    "    :return: Dict with \"accuracy\", \"loss\" and \"total_correct\"\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss = 0\n",
    "    total_correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in tqdm(loader, leave=False):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            loss += criterion(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
    "            total_correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "    \n",
    "    return {\"accuracy\": total_correct / len(loader.dataset), \n",
    "            \"loss\": loss / len(loader.dataset), \n",
    "            \"total_correct\": total_correct}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training parameters\n",
    "LEARNING_RATE = 0.02\n",
    "LEARNING_RATE_GAMMA = 0.8\n",
    "MOMENTUM = 0.0\n",
    "EPOCHS = 15\n",
    "FIRST_EPOCH_BATCH_SIZE = 4\n",
    "TRAIN_BATCH_SIZE = 64\n",
    "TEST_BATCH_SIZE = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Sparse MNIST model\n",
    "\n",
    "There are 2 ways to create **nupic.torch** sparse models. You can import the models from **nupic.torch.models** or use pytorch's [torch.hub](https://pytorch.org/docs/stable/hub.html) API.\n",
    "\n",
    "In this example we will import the models. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNISTSparseCNN(\n",
      "  (cnn1): SparseWeights2d(\n",
      "    sparsity=0.4\n",
      "    (module): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))\n",
      "  )\n",
      "  (cnn1_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (cnn1_kwinner): KWinners2d(channels=32, local=False, n=0, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n",
      "  (cnn2): SparseWeights2d(\n",
      "    sparsity=0.55\n",
      "    (module): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))\n",
      "  )\n",
      "  (cnn2_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (cnn2_kwinner): KWinners2d(channels=64, local=False, n=0, percent_on=0.2, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n",
      "  (flatten): Flatten()\n",
      "  (linear): SparseWeights(\n",
      "    sparsity=0.8\n",
      "    (module): Linear(in_features=1024, out_features=700, bias=True)\n",
      "  )\n",
      "  (linear_kwinner): KWinners(n=700, percent_on=0.2, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n",
      "  (output): Linear(in_features=700, out_features=10, bias=True)\n",
      "  (softmax): LogSoftmax()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from nupic.torch.models import MNISTSparseCNN\n",
    "# For this example we will use the default values. \n",
    "# See MNISTSparseCNN documentation for all possible parameters and their values.\n",
    "model = MNISTSparseCNN().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "normalize = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])\n",
    "train_dataset = datasets.MNIST('data', train=True, download=True, transform=normalize)\n",
    "test_dataset = datasets.MNIST('data', train=False, transform=normalize)\n",
    "\n",
    "# Configure data loaders\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=True)\n",
    "first_loader = torch.utils.data.DataLoader(train_dataset, batch_size=FIRST_EPOCH_BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train\n",
    "On the first epoch we use smaller batch size to calculate the duty cycles used by the k-winner function. Once the duty cycles stabilize we can use larger batch sizes. Using the `post_batch`, we rezero the weights after every batch to keep the initial sparsity constant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "from nupic.torch.modules import rezero_weights, update_boost_strength\n",
    "\n",
    "def post_batch(model):\n",
    "    model.apply(rezero_weights)\n",
    "\n",
    "sgd = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
    "lr_scheduler = optim.lr_scheduler.StepLR(sgd, step_size=1, gamma=LEARNING_RATE_GAMMA)\n",
    "train(model=model, loader=first_loader, optimizer=sgd, criterion=F.nll_loss, post_batch_callback=post_batch)\n",
    "lr_scheduler.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After each we apply the boost strength factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "model.apply(update_boost_strength)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test and print results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.9823, 'loss': 0.05553661346435547, 'total_correct': 9823}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(model=model, loader=test_loader, criterion=F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point the duty cycles should be stable and we can train on larger batch sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.99, 'loss': 0.03087631359100342, 'total_correct': 9900}\n",
      "{'accuracy': 0.9913, 'loss': 0.028131033515930177, 'total_correct': 9913}\n",
      "{'accuracy': 0.9912, 'loss': 0.02829796772003174, 'total_correct': 9912}\n",
      "{'accuracy': 0.9901, 'loss': 0.03012564697265625, 'total_correct': 9901}\n",
      "{'accuracy': 0.9909, 'loss': 0.0279187762260437, 'total_correct': 9909}\n",
      "{'accuracy': 0.9913, 'loss': 0.026879340744018553, 'total_correct': 9913}\n",
      "{'accuracy': 0.991, 'loss': 0.027533163070678712, 'total_correct': 9910}\n",
      "{'accuracy': 0.9915, 'loss': 0.026872401237487794, 'total_correct': 9915}\n",
      "{'accuracy': 0.9916, 'loss': 0.02652479610443115, 'total_correct': 9916}\n",
      "{'accuracy': 0.9915, 'loss': 0.02673978271484375, 'total_correct': 9915}\n",
      "{'accuracy': 0.9916, 'loss': 0.026827363300323485, 'total_correct': 9916}\n",
      "{'accuracy': 0.9914, 'loss': 0.02654646396636963, 'total_correct': 9914}\n",
      "{'accuracy': 0.9916, 'loss': 0.026357748794555665, 'total_correct': 9916}\n",
      "{'accuracy': 0.9917, 'loss': 0.026316686820983887, 'total_correct': 9917}\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, EPOCHS):\n",
    "    train(model=model, loader=train_loader, optimizer=sgd, criterion=F.nll_loss, post_batch_callback=post_batch)\n",
    "    lr_scheduler.step()\n",
    "    model.apply(rezero_weights)\n",
    "    model.apply(update_boost_strength)\n",
    "    results = test(model=model, loader=test_loader, criterion=F.nll_loss)\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Noise\n",
    "Add noise to the input and check the test accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomNoise(object):\n",
    "    \"\"\"\n",
    "    An image transform that adds noise to random pixels in the image.\n",
    "    \"\"\"\n",
    "    def __init__(self, noise_level=0.0, white_value=0.1307 + 2*0.3081):\n",
    "        \"\"\"\n",
    "        :param noise_level:\n",
    "          From 0 to 1. For each pixel, set its value to white_value with this\n",
    "          probability. Suggested white_value is 'mean + 2*stdev'\n",
    "        \"\"\"\n",
    "        self.noise_level = noise_level\n",
    "        self.white_value = white_value\n",
    "\n",
    "    def __call__(self, image):\n",
    "        a = image.view(-1)\n",
    "        num_noise_bits = int(a.shape[0] * self.noise_level)\n",
    "        noise = np.random.permutation(a.shape[0])[0:num_noise_bits]\n",
    "        a[noise] = self.white_value\n",
    "        return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0 : {'accuracy': 0.9917, 'loss': 0.02631667709350586, 'total_correct': 9917}\n",
      "0.05 : {'accuracy': 0.9899, 'loss': 0.032843858909606935, 'total_correct': 9899}\n",
      "0.1 : {'accuracy': 0.9855, 'loss': 0.0439798023223877, 'total_correct': 9855}\n",
      "0.15 : {'accuracy': 0.9829, 'loss': 0.05501823501586914, 'total_correct': 9829}\n",
      "0.2 : {'accuracy': 0.976, 'loss': 0.07691587409973144, 'total_correct': 9760}\n",
      "0.25 : {'accuracy': 0.9655, 'loss': 0.10933124389648438, 'total_correct': 9655}\n",
      "0.3 : {'accuracy': 0.9499, 'loss': 0.15266261291503908, 'total_correct': 9499}\n",
      "0.35 : {'accuracy': 0.9305, 'loss': 0.21989019317626954, 'total_correct': 9305}\n",
      "0.4 : {'accuracy': 0.9015, 'loss': 0.30402026062011717, 'total_correct': 9015}\n",
      "0.45 : {'accuracy': 0.8543, 'loss': 0.4642761535644531, 'total_correct': 8543}\n",
      "0.5 : {'accuracy': 0.7951, 'loss': 0.6449000061035156, 'total_correct': 7951}\n"
     ]
    }
   ],
   "source": [
    "noise_score = 0\n",
    "for noise in [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:\n",
    "    noise_transform = transforms.Compose([transforms.ToTensor(), RandomNoise(noise), \n",
    "                                      transforms.Normalize((0.1307,), (0.3081,))])\n",
    "    noise_dataset = datasets.MNIST('data', train=False, transform=noise_transform)\n",
    "    noise_loader = torch.utils.data.DataLoader(noise_dataset, \n",
    "                                               batch_size=TEST_BATCH_SIZE, \n",
    "                                               shuffle=True)\n",
    "\n",
    "    results = test(model=model, loader=noise_loader, criterion=F.nll_loss)\n",
    "    noise_score += results[\"total_correct\"]\n",
    "    print(noise, \":\", results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "noise_score: 103228\n"
     ]
    }
   ],
   "source": [
    "print(\"noise_score:\", noise_score)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
